{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import xlrd \n",
    "import numpy\n",
    "import time\n",
    "from sklearn.model_selection import cross_validate\n",
    "from sklearn.model_selection import cross_val_score\n",
    "import numpy as np\n",
    "from sklearn.model_selection import LeaveOneOut\n",
    "import math\n",
    "import numpy.linalg as LA\n",
    "from sklearn.cluster import KMeans\n",
    "import random\n",
    "from sklearn import tree\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "from sklearn import model_selection,metrics\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "import matplotlib.pylab as plt\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from scipy import interp\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "random.sample\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import KFold\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import roc_auc_score, auc\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import classification_report\n",
    "from collections import Counter\n",
    "startTime = time.time()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_data(directory):\n",
    "    GSSM = np.loadtxt(directory + '\\GSSM_.txt',dtype=np.float32)\n",
    "    PESSM = np.loadtxt(directory + '\\PSSM.txt',dtype=np.float32,delimiter='\\t')\n",
    "\n",
    "    IPE = pd.DataFrame(PESSM).reset_index()\n",
    "    IG = pd.DataFrame(GSSM).reset_index()\n",
    "    IPE.rename(columns = {'index':'id'}, inplace = True)\n",
    "    IG.rename(columns = {'index':'id'}, inplace = True)\n",
    "    IPE['id'] = IPE['id']\n",
    "    IG['id'] = IG['id']\n",
    "    \n",
    "    return IPE, IG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(directory, random_seed):\n",
    "    all_associations = pd.read_csv(directory + '/all_gpe_pairs.csv')\n",
    "    known_associations = all_associations.loc[all_associations['label'] == 1]\n",
    "    unknown_associations = all_associations.loc[all_associations['label'] == 0]\n",
    "    random_negative = unknown_associations.sample(n=known_associations.shape[0], random_state=random_seed, axis=0)\n",
    "\n",
    "    sample_df = known_associations.append(random_negative)\n",
    "    sample_df.reset_index(drop=True, inplace=True)\n",
    "\n",
    "    return sample_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def obtain_data(directory, isbalance, random_seed=1234):\n",
    "    \"\"\"Load the pair table, join the feature matrices onto it and return shuffled id lists.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    directory : str\n",
    "        Data folder handed to load_data / sample / read_csv.\n",
    "    isbalance : bool\n",
    "        If true, use the balanced table produced by sample(); otherwise read\n",
    "        every row of 'all_gene_peco_pairs.csv'.\n",
    "    random_seed : int, optional\n",
    "        Seed for the negative sampling and for shuffling the id lists, so the\n",
    "        id-based folds built from them are reproducible. Defaults to 1234.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (dtp, gene_ids, peco_ids, gene_test_num, peco_test_num, samples) where\n",
    "        gene_ids / peco_ids are the shuffled unique values of the 'gene_idx' /\n",
    "        'peco_idx' columns, the *_test_num values are one fifth of the id counts\n",
    "        (rounded down), and samples is dtp inner-joined with the IPE rows on\n",
    "        peco_idx == id and the IG rows on gene_idx == id, with the two id\n",
    "        columns dropped.\n",
    "    \"\"\"\n",
    "    IPE, IG = load_data(directory)\n",
    "\n",
    "    if isbalance:\n",
    "        dtp = sample(directory, random_seed = random_seed)\n",
    "    else:\n",
    "        dtp = pd.read_csv(directory + '/all_gene_peco_pairs.csv')\n",
    "\n",
    "    # Sort before shuffling and shuffle with a private seeded generator: set\n",
    "    # iteration order plus an unseeded global shuffle made the folds differ\n",
    "    # from run to run.\n",
    "    rng = random.Random(random_seed)\n",
    "    gene_ids = sorted(set(dtp['gene_idx']))\n",
    "    peco_ids = sorted(set(dtp['peco_idx']))\n",
    "    rng.shuffle(gene_ids)\n",
    "    rng.shuffle(peco_ids)\n",
    "    print('# gene = {} | peco = {}'.format(len(gene_ids), len(peco_ids)))\n",
    "\n",
    "    gene_test_num = int(len(gene_ids) / 5)\n",
    "    peco_test_num = int(len(peco_ids) / 5)\n",
    "    print('# Test: gene = {} | peco = {}'.format(gene_test_num, peco_test_num))    \n",
    "\n",
    "    # Both feature frames carry an 'id' column, so after the second merge they\n",
    "    # show up as 'id_x' / 'id_y' and are removed again.\n",
    "    with_peco = pd.merge(dtp, IPE, left_on = 'peco_idx', right_on = 'id')\n",
    "    samples = pd.merge(with_peco, IG, left_on = 'gene_idx', right_on = 'id')\n",
    "    samples = samples.drop(labels = ['id_x', 'id_y'], axis = 1)\n",
    "\n",
    "    return dtp, gene_ids, peco_ids, gene_test_num, peco_test_num, samples\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def performances(y_true, y_pred, y_prob):\n",
    "    \"\"\"Print and return binary-classification metrics for one evaluation set.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    y_true : sequence of 0/1 labels\n",
    "    y_pred : sequence of 0/1 predicted labels\n",
    "    y_prob : sequence of scores for the positive class\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        ((y_true, y_pred, y_prob),\n",
    "         (accuracy, precision, recall, f1, roc_auc, aupr)).\n",
    "        Recall, precision and f1 fall back to 0 when their denominator is 0.\n",
    "    \"\"\"\n",
    "    # Fix the label order so ravel() always yields tn, fp, fn, tp.\n",
    "    counts = confusion_matrix(y_true, y_pred, labels = [0, 1])\n",
    "    tn, fp, fn, tp = counts.ravel().tolist()\n",
    "\n",
    "    accuracy = (tp+tn)/(tn+fp+fn+tp)\n",
    "    recall = 0 if tp+fn == 0 else tp / (tp+fn)\n",
    "    precision = 0 if tp+fp == 0 else tp / (tp+fp)\n",
    "    f1 = 0 if precision + recall == 0 else 2*precision*recall / (precision+recall)\n",
    "\n",
    "    roc_auc = roc_auc_score(y_true, y_prob)\n",
    "    # Area under the precision-recall curve (recall on the x axis).\n",
    "    prec, reca, _ = precision_recall_curve(y_true, y_prob)\n",
    "    aupr = auc(reca, prec)\n",
    "\n",
    "    print('tn = {}, fp = {}, fn = {}, tp = {}'.format(tn, fp, fn, tp))\n",
    "    pred_counts = Counter(y_pred)\n",
    "    print('y_pred: 0 = {} | 1 = {}'.format(pred_counts[0], pred_counts[1]))\n",
    "    true_counts = Counter(y_true)\n",
    "    print('y_true: 0 = {} | 1 = {}'.format(true_counts[0], true_counts[1]))\n",
    "    metrics_row = (accuracy, precision, recall, f1, roc_auc, aupr)\n",
    "    print('acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}'.format(*metrics_row))\n",
    "    return (y_true, y_pred, y_prob), metrics_row\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_task_Tp_train_test_idx(samples, dtp=None):\n",
    "    \"\"\"Create the 5-fold random pair-level split (task Tp).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    samples : pandas.DataFrame\n",
    "        Table whose rows are split; only its row count matters for the split.\n",
    "    dtp : pandas.DataFrame, optional\n",
    "        Table with 'gene_idx' and 'peco_idx' columns, row-aligned with\n",
    "        samples, from which the per-fold id pairs are read. Defaults to\n",
    "        samples itself, so the function no longer depends on a notebook-level\n",
    "        global named dtp.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of four lists (one entry per fold)\n",
    "        train_index_all, test_index_all: positional row indices from KFold;\n",
    "        train_id_all, test_id_all: arrays of the (gene_idx, peco_idx) pairs of\n",
    "        those rows.\n",
    "    \"\"\"\n",
    "    if dtp is None:\n",
    "        dtp = samples\n",
    "\n",
    "    kf = KFold(n_splits = 5, shuffle = True, random_state = 1234)\n",
    "\n",
    "    train_index_all, test_index_all = [], []\n",
    "    train_id_all, test_id_all = [], []\n",
    "    for fold, (train_idx, test_idx) in enumerate(tqdm(kf.split(samples.iloc[:, 3:]))):\n",
    "        print('-------Fold ', fold)\n",
    "        train_index_all.append(train_idx) \n",
    "        test_index_all.append(test_idx)\n",
    "\n",
    "        train_id_all.append(np.array(dtp.iloc[train_idx][['gene_idx', 'peco_idx']]))\n",
    "        test_id_all.append(np.array(dtp.iloc[test_idx][['gene_idx', 'peco_idx']]))\n",
    "\n",
    "        print('# Pairs: Train = {} | Test = {}'.format(len(train_idx), len(test_idx)))\n",
    "    return train_index_all, test_index_all, train_id_all, test_id_all\n",
    "\n",
    "\n",
    "def generate_task_Tg_Tpe_train_test_idx(item, ids, dtp):\n",
    "    \n",
    "    test_num = int(len(ids) / 5)\n",
    "    \n",
    "    train_index_all, test_index_all = [], []\n",
    "    train_id_all, test_id_all = [], []\n",
    "    \n",
    "    for fold in range(5):\n",
    "        print('-------Fold ', fold)\n",
    "        if fold != 4:\n",
    "            test_ids = ids[fold * test_num : (fold + 1) * test_num]\n",
    "        else:\n",
    "            test_ids = ids[fold * test_num :]\n",
    "\n",
    "        train_ids = list(set(ids) ^ set(test_ids))\n",
    "        print('# {}: Train = {} | Test = {}'.format(item, len(train_ids), len(test_ids)))\n",
    "\n",
    "        test_idx = dtp[dtp[item].isin(test_ids)].index.tolist()\n",
    "        train_idx = dtp[dtp[item].isin(train_ids)].index.tolist()\n",
    "        random.shuffle(test_idx)\n",
    "        random.shuffle(train_idx)\n",
    "        print('# Pairs: Train = {} | Test = {}'.format(len(train_idx), len(test_idx)))\n",
    "        assert len(train_idx) + len(test_idx) == len(dtp)\n",
    "\n",
    "        train_index_all.append(train_idx) \n",
    "        test_index_all.append(test_idx)\n",
    "        \n",
    "        train_id_all.append(train_ids)\n",
    "        test_id_all.append(test_ids)\n",
    "        \n",
    "    return train_index_all, test_index_all, train_id_all, test_id_all\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "def transfer(y_pred):\n",
    "    return [[0,1][x>0.5] for x in y_pred.reshape(-1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run(train_index_all, test_index_all, samples, random_state=None):\n",
    "    \"\"\"Fit and evaluate a GBDT + one-hot + logistic-regression model on every fold.\n",
    "\n",
    "    For each fold a GradientBoostingClassifier is trained, the leaf index each\n",
    "    sample reaches in every tree is one-hot encoded, and a LogisticRegression\n",
    "    is trained on that encoding. Metrics for the train and test part are\n",
    "    printed through performances().\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    train_index_all, test_index_all : list\n",
    "        Per-fold row positions into samples.\n",
    "    samples : pandas.DataFrame\n",
    "        Must contain a 'label' column; columns from position 3 onwards are\n",
    "        used as features.\n",
    "    random_state : int or None, optional\n",
    "        Forwarded to GradientBoostingClassifier. None (default) keeps the\n",
    "        estimator's unseeded behaviour.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (y_train_pred, y_test_pred, y_train_prob, y_test_prob,\n",
    "        performances_train, performances_test) of the LAST fold only; earlier\n",
    "        folds are reported via the printed output.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no fold is supplied.\n",
    "    \"\"\"\n",
    "    # Feature / label views do not depend on the fold, so build them once.\n",
    "    X = samples.iloc[:, 3:]\n",
    "    y = samples['label']\n",
    "\n",
    "    last_fold = None\n",
    "    for fold, (train_idx, test_idx) in enumerate(zip(train_index_all, test_index_all)):\n",
    "        print('----------------------- Fold = ', str(fold))\n",
    "\n",
    "        x_train, y_train = X.iloc[train_idx], y.iloc[train_idx]\n",
    "        x_test, y_test = X.iloc[test_idx], y.iloc[test_idx]\n",
    "\n",
    "        GBDT = GradientBoostingClassifier(n_estimators = 12, max_depth = 5, min_samples_leaf = 13,\n",
    "                                          random_state = random_state)\n",
    "        GBDT.fit(x_train, y_train)\n",
    "\n",
    "        # apply() returns the leaf reached in each tree; compute it once per\n",
    "        # split instead of once per use.\n",
    "        train_leaves = GBDT.apply(x_train)[:, :, 0]\n",
    "        test_leaves = GBDT.apply(x_test)[:, :, 0]\n",
    "\n",
    "        OHE = OneHotEncoder()\n",
    "        train_encoded = OHE.fit_transform(train_leaves)\n",
    "        test_encoded = OHE.transform(test_leaves)\n",
    "\n",
    "        LR = LogisticRegression(n_jobs = -1)\n",
    "        LR.fit(train_encoded, y_train)\n",
    "\n",
    "        y_train_prob = LR.predict_proba(train_encoded)[:,1]\n",
    "        y_test_prob = LR.predict_proba(test_encoded)[:,1]\n",
    "\n",
    "        y_train_pred, y_test_pred = transfer(y_train_prob), transfer(y_test_prob)\n",
    "\n",
    "        print(len(y_train_pred), len(y_train))\n",
    "        performances_train = performances(y_train, y_train_pred, y_train_prob)\n",
    "        performances_test = performances(y_test, y_test_pred, y_test_prob)\n",
    "\n",
    "        last_fold = (y_train_pred, y_test_pred, y_train_prob, y_test_prob, performances_train, performances_test)\n",
    "\n",
    "    if last_fold is None:\n",
    "        raise ValueError('run() needs at least one train/test fold')\n",
    "\n",
    "    return last_fold\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run tasks: Tp, Tg, Tpe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "5it [00:00, 161.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# gene = 11177 | peco = 24\n",
      "# Test: gene = 2235 | peco = 4\n",
      "========== isbalance = True | task = Tp\n",
      "-------Fold  0\n",
      "# Pairs: Train = 37692 | Test = 9424\n",
      "-------Fold  1\n",
      "# Pairs: Train = 37693 | Test = 9423\n",
      "-------Fold  2\n",
      "# Pairs: Train = 37693 | Test = 9423\n",
      "-------Fold  3\n",
      "# Pairs: Train = 37693 | Test = 9423\n",
      "-------Fold  4\n",
      "# Pairs: Train = 37693 | Test = 9423\n",
      "----------------------- Fold =  0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37692 37692\n",
      "tn = 15778, fp = 3112, fn = 1988, tp = 16814\n",
      "y_pred: 0 = 17766 | 1 = 19926\n",
      "y_true: 0 = 18890 | 1 = 18802\n",
      "acc=0.8647|precision=0.8438|recall=0.8943|f1=0.8683|auc=0.9466|aupr=0.9481\n",
      "tn = 3895, fp = 773, fn = 514, tp = 4242\n",
      "y_pred: 0 = 4409 | 1 = 5015\n",
      "y_true: 0 = 4668 | 1 = 4756\n",
      "acc=0.8634|precision=0.8459|recall=0.8919|f1=0.8683|auc=0.9416|aupr=0.9437\n",
      "----------------------- Fold =  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37693 37693\n",
      "tn = 15735, fp = 3086, fn = 2035, tp = 16837\n",
      "y_pred: 0 = 17770 | 1 = 19923\n",
      "y_true: 0 = 18821 | 1 = 18872\n",
      "acc=0.8641|precision=0.8451|recall=0.8922|f1=0.8680|auc=0.9459|aupr=0.9476\n",
      "tn = 3957, fp = 780, fn = 524, tp = 4162\n",
      "y_pred: 0 = 4481 | 1 = 4942\n",
      "y_true: 0 = 4737 | 1 = 4686\n",
      "acc=0.8616|precision=0.8422|recall=0.8882|f1=0.8646|auc=0.9429|aupr=0.9425\n",
      "----------------------- Fold =  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37693 37693\n",
      "tn = 15826, fp = 3055, fn = 2042, tp = 16770\n",
      "y_pred: 0 = 17868 | 1 = 19825\n",
      "y_true: 0 = 18881 | 1 = 18812\n",
      "acc=0.8648|precision=0.8459|recall=0.8915|f1=0.8681|auc=0.9468|aupr=0.9485\n",
      "tn = 3895, fp = 782, fn = 545, tp = 4201\n",
      "y_pred: 0 = 4440 | 1 = 4983\n",
      "y_true: 0 = 4677 | 1 = 4746\n",
      "acc=0.8592|precision=0.8431|recall=0.8852|f1=0.8636|auc=0.9415|aupr=0.9421\n",
      "----------------------- Fold =  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37693 37693\n",
      "tn = 15782, fp = 3012, fn = 2075, tp = 16824\n",
      "y_pred: 0 = 17857 | 1 = 19836\n",
      "y_true: 0 = 18794 | 1 = 18899\n",
      "acc=0.8650|precision=0.8482|recall=0.8902|f1=0.8687|auc=0.9466|aupr=0.9488\n",
      "tn = 3970, fp = 794, fn = 553, tp = 4106\n",
      "y_pred: 0 = 4523 | 1 = 4900\n",
      "y_true: 0 = 4764 | 1 = 4659\n",
      "acc=0.8571|precision=0.8380|recall=0.8813|f1=0.8591|auc=0.9397|aupr=0.9389\n",
      "----------------------- Fold =  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37693 37693\n",
      "tn = 15754, fp = 3092, fn = 1987, tp = 16860\n",
      "y_pred: 0 = 17741 | 1 = 19952\n",
      "y_true: 0 = 18846 | 1 = 18847\n",
      "acc=0.8653|precision=0.8450|recall=0.8946|f1=0.8691|auc=0.9466|aupr=0.9482\n",
      "tn = 3879, fp = 833, fn = 508, tp = 4203\n",
      "y_pred: 0 = 4387 | 1 = 5036\n",
      "y_true: 0 = 4712 | 1 = 4711\n",
      "acc=0.8577|precision=0.8346|recall=0.8922|f1=0.8624|auc=0.9411|aupr=0.9418\n",
      "========== isbalance = True | task = Tg\n",
      "-------Fold  0\n",
      "# gene_idx: Train = 8942 | Test = 2235\n",
      "# Pairs: Train = 37655 | Test = 9461\n",
      "-------Fold  1\n",
      "# gene_idx: Train = 8942 | Test = 2235\n",
      "# Pairs: Train = 37728 | Test = 9388\n",
      "-------Fold  2\n",
      "# gene_idx: Train = 8942 | Test = 2235\n",
      "# Pairs: Train = 37803 | Test = 9313\n",
      "-------Fold  3\n",
      "# gene_idx: Train = 8942 | Test = 2235\n",
      "# Pairs: Train = 37541 | Test = 9575\n",
      "-------Fold  4\n",
      "# gene_idx: Train = 8940 | Test = 2237\n",
      "# Pairs: Train = 37737 | Test = 9379\n",
      "----------------------- Fold =  0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37655 37655\n",
      "tn = 15696, fp = 3141, fn = 1912, tp = 16906\n",
      "y_pred: 0 = 17608 | 1 = 20047\n",
      "y_true: 0 = 18837 | 1 = 18818\n",
      "acc=0.8658|precision=0.8433|recall=0.8984|f1=0.8700|auc=0.9472|aupr=0.9485\n",
      "tn = 3900, fp = 821, fn = 552, tp = 4188\n",
      "y_pred: 0 = 4452 | 1 = 5009\n",
      "y_true: 0 = 4721 | 1 = 4740\n",
      "acc=0.8549|precision=0.8361|recall=0.8835|f1=0.8592|auc=0.9380|aupr=0.9403\n",
      "----------------------- Fold =  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37728 37728\n",
      "tn = 15725, fp = 3093, fn = 2033, tp = 16877\n",
      "y_pred: 0 = 17758 | 1 = 19970\n",
      "y_true: 0 = 18818 | 1 = 18910\n",
      "acc=0.8641|precision=0.8451|recall=0.8925|f1=0.8682|auc=0.9464|aupr=0.9486\n",
      "tn = 3943, fp = 797, fn = 490, tp = 4158\n",
      "y_pred: 0 = 4433 | 1 = 4955\n",
      "y_true: 0 = 4740 | 1 = 4648\n",
      "acc=0.8629|precision=0.8392|recall=0.8946|f1=0.8660|auc=0.9425|aupr=0.9420\n",
      "----------------------- Fold =  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37803 37803\n",
      "tn = 15831, fp = 3081, fn = 2051, tp = 16840\n",
      "y_pred: 0 = 17882 | 1 = 19921\n",
      "y_true: 0 = 18912 | 1 = 18891\n",
      "acc=0.8642|precision=0.8453|recall=0.8914|f1=0.8678|auc=0.9460|aupr=0.9473\n",
      "tn = 3862, fp = 784, fn = 486, tp = 4181\n",
      "y_pred: 0 = 4348 | 1 = 4965\n",
      "y_true: 0 = 4646 | 1 = 4667\n",
      "acc=0.8636|precision=0.8421|recall=0.8959|f1=0.8681|auc=0.9443|aupr=0.9460\n",
      "----------------------- Fold =  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37541 37541\n",
      "tn = 15633, fp = 3096, fn = 1929, tp = 16883\n",
      "y_pred: 0 = 17562 | 1 = 19979\n",
      "y_true: 0 = 18729 | 1 = 18812\n",
      "acc=0.8661|precision=0.8450|recall=0.8975|f1=0.8705|auc=0.9469|aupr=0.9486\n",
      "tn = 4004, fp = 825, fn = 546, tp = 4200\n",
      "y_pred: 0 = 4550 | 1 = 5025\n",
      "y_true: 0 = 4829 | 1 = 4746\n",
      "acc=0.8568|precision=0.8358|recall=0.8850|f1=0.8597|auc=0.9402|aupr=0.9409\n",
      "----------------------- Fold =  4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37737 37737\n",
      "tn = 15891, fp = 3045, fn = 2086, tp = 16715\n",
      "y_pred: 0 = 17977 | 1 = 19760\n",
      "y_true: 0 = 18936 | 1 = 18801\n",
      "acc=0.8640|precision=0.8459|recall=0.8890|f1=0.8669|auc=0.9473|aupr=0.9489\n",
      "tn = 3857, fp = 765, fn = 562, tp = 4195\n",
      "y_pred: 0 = 4419 | 1 = 4960\n",
      "y_true: 0 = 4622 | 1 = 4757\n",
      "acc=0.8585|precision=0.8458|recall=0.8819|f1=0.8634|auc=0.9392|aupr=0.9408\n",
      "========== isbalance = True | task = Tpe\n",
      "-------Fold  0\n",
      "# peco_idx: Train = 20 | Test = 4\n",
      "# Pairs: Train = 42485 | Test = 4631\n",
      "-------Fold  1\n",
      "# peco_idx: Train = 20 | Test = 4\n",
      "# Pairs: Train = 38907 | Test = 8209\n",
      "-------Fold  2\n",
      "# peco_idx: Train = 20 | Test = 4\n",
      "# Pairs: Train = 37261 | Test = 9855\n",
      "-------Fold  3\n",
      "# peco_idx: Train = 20 | Test = 4\n",
      "# Pairs: Train = 32149 | Test = 14967\n",
      "-------Fold  4\n",
      "# peco_idx: Train = 16 | Test = 8\n",
      "# Pairs: Train = 37662 | Test = 9454\n",
      "----------------------- Fold =  0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "42485 42485\n",
      "tn = 15792, fp = 3373, fn = 2807, tp = 20513\n",
      "y_pred: 0 = 18599 | 1 = 23886\n",
      "y_true: 0 = 19165 | 1 = 23320\n",
      "acc=0.8545|precision=0.8588|recall=0.8796|f1=0.8691|auc=0.9407|aupr=0.9513\n",
      "tn = 3197, fp = 1196, fn = 169, tp = 69\n",
      "y_pred: 0 = 3366 | 1 = 1265\n",
      "y_true: 0 = 4393 | 1 = 238\n",
      "acc=0.7052|precision=0.0545|recall=0.2899|f1=0.0918|auc=0.5663|aupr=0.0573\n",
      "----------------------- Fold =  1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38907 38907\n",
      "tn = 18005, fp = 1616, fn = 2366, tp = 16920\n",
      "y_pred: 0 = 20371 | 1 = 18536\n",
      "y_true: 0 = 19621 | 1 = 19286\n",
      "acc=0.8977|precision=0.9128|recall=0.8773|f1=0.8947|auc=0.9652|aupr=0.9664\n",
      "tn = 3934, fp = 3, fn = 4262, tp = 10\n",
      "y_pred: 0 = 8196 | 1 = 13\n",
      "y_true: 0 = 3937 | 1 = 4272\n",
      "acc=0.4804|precision=0.7692|recall=0.0023|f1=0.0047|auc=0.5194|aupr=0.5424\n",
      "----------------------- Fold =  2\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37261 37261\n",
      "tn = 17015, fp = 2909, fn = 1759, tp = 15578\n",
      "y_pred: 0 = 18774 | 1 = 18487\n",
      "y_true: 0 = 19924 | 1 = 17337\n",
      "acc=0.8747|precision=0.8426|recall=0.8985|f1=0.8697|auc=0.9546|aupr=0.9523\n",
      "tn = 3633, fp = 1, fn = 6214, tp = 7\n",
      "y_pred: 0 = 9847 | 1 = 8\n",
      "y_true: 0 = 3634 | 1 = 6221\n",
      "acc=0.3694|precision=0.8750|recall=0.0011|f1=0.0022|auc=0.7723|aupr=0.8434\n",
      "----------------------- Fold =  3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32149 32149\n",
      "tn = 17892, fp = 2473, fn = 2646, tp = 9138\n",
      "y_pred: 0 = 20538 | 1 = 11611\n",
      "y_true: 0 = 20365 | 1 = 11784\n",
      "acc=0.8408|precision=0.7870|recall=0.7755|f1=0.7812|auc=0.9191|aupr=0.8679\n",
      "tn = 1222, fp = 1971, fn = 10562, tp = 1212\n",
      "y_pred: 0 = 11784 | 1 = 3183\n",
      "y_true: 0 = 3193 | 1 = 11774\n",
      "acc=0.1626|precision=0.3808|recall=0.1029|f1=0.1621|auc=0.0843|aupr=0.5260\n",
      "----------------------- Fold =  4\n",
      "37662 37662\n",
      "tn = 11987, fp = 3170, fn = 2217, tp = 20288\n",
      "y_pred: 0 = 14204 | 1 = 23458\n",
      "y_true: 0 = 15157 | 1 = 22505\n",
      "acc=0.8570|precision=0.8649|recall=0.9015|f1=0.8828|auc=0.9365|aupr=0.9564\n",
      "tn = 6583, fp = 1818, fn = 197, tp = 856\n",
      "y_pred: 0 = 6780 | 1 = 2674\n",
      "y_true: 0 = 8401 | 1 = 1053\n",
      "acc=0.7869|precision=0.3201|recall=0.8129|f1=0.4594|auc=0.8204|aupr=0.4335\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-37-162246ac2ecc>:2: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred.reshape(-1)]\n"
     ]
    }
   ],
   "source": [
    "# Driver: load the data once per balance setting, then build the folds for\n",
    "# each task and train / evaluate on them.\n",
    "directory = '../../data'\n",
    "for isbalance in [True]:\n",
    "    dtp, gene_ids, peco_ids, gene_test_num, peco_test_num, samples = obtain_data(directory,isbalance)  \n",
    "    # Id column and id ordering used by the two id-disjoint tasks.\n",
    "    id_based_tasks = {'Tg': ('gene_idx', gene_ids), 'Tpe': ('peco_idx', peco_ids)}\n",
    "    for task in ['Tp', 'Tg', 'Tpe']:\n",
    "\n",
    "        print('========== isbalance = {} | task = {}'.format(isbalance, task))\n",
    "\n",
    "        if task == 'Tp':\n",
    "            # Random split over pairs.\n",
    "            folds = generate_task_Tp_train_test_idx(samples)\n",
    "        else:\n",
    "            # Split by held-out ids of the task's id column.\n",
    "            item, ids = id_based_tasks[task]\n",
    "            folds = generate_task_Tg_Tpe_train_test_idx(item, ids, samples)\n",
    "        train_index_all, test_index_all, train_id_all, test_id_all = folds\n",
    "\n",
    "        y_train_pred, y_test_pred, y_train_prob, y_test_prob, performances_train, performances_test = run(train_index_all, test_index_all, samples)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
